from matplotlib.animation import FuncAnimation
# Restrict to STATE == '06' once (presumably the California FIPS code --
# confirm against where merged_df is built) instead of re-filtering for each
# derived value below.
ca_df = merged_df[merged_df['STATE'] == '06']

# One animation frame per date (groupby sorts its keys, so frames come out in
# date order): the 10 rows with the most confirmed cases, kept in ascending
# order so the largest value gets the highest bar position. Missing counts
# are sorted to the front so they can never take a slot in the top 10.
# Grouping by the plain column name (not a 1-element list) avoids pandas'
# tuple-key behaviour/warning when iterating the groups.
plot_data_frames = [
    gdf.sort_values('totalcountconfirmed', na_position='first').tail(10)
    for _, gdf in ca_df.groupby('date')
]

# Fixed county -> color mapping so each county keeps its color in every frame.
# NOTE(review): "tab20" has only 20 distinct colors; seaborn cycles them when
# more are requested, so counties 20 positions apart in `counties` share a
# color -- consider a larger palette if that is confusing.
counties = sorted(ca_df['county'].unique())
colors = dict(zip(counties, sns.color_palette("tab20", n_colors=len(counties))))
def get_colors(counties):
    """Return the color for each county name in *counties*, in input order.

    Colors come from the module-level ``colors`` mapping; a county that is
    not in that mapping raises ``KeyError``.
    """
    lookup = colors.__getitem__
    return list(map(lookup, counties))
def nice_axes(ax):
    """Apply a minimal look to *ax* in place.

    Tick labels are drawn at size 8 with zero-length tick marks, ticks and
    grid lines are placed below the plotted artists, and every spine is
    hidden. Returns ``None``.
    """
    ax.tick_params(labelsize=8, length=0)
    # Keep ticks/grid lines behind the bars rather than on top of them.
    ax.set_axisbelow(True)
    # Hide each axis spine (the frame around the plot area). A plain loop is
    # used because this is done purely for its side effect.
    for spine in ax.spines.values():
        spine.set_visible(False)
def init():
    """Prepare the shared ``ax`` for the first frame.

    Wipes everything currently drawn on the axes, then reapplies the
    ``nice_axes`` styling. No y-limits are set, so matplotlib autoscales.
    """
    target = ax
    target.clear()
    nice_axes(target)
def update(i):
    """Draw animation frame *i* on the shared ``ax``.

    FuncAnimation calls this with frame indices ``0 .. len(plot_data_frames)-1``.
    The previous frame's bar containers are removed, then one horizontal bar
    is drawn per row of ``plot_data_frames[i]``: width from
    ``totalcountconfirmed``, tick label and color from ``county``. The axes
    title shows the frame's date.
    """
    # Container.remove() also detaches the container from ax.containers, so
    # iterate over a snapshot; iterating the live list would skip entries.
    for container in list(ax.containers):
        container.remove()

    frame_df = plot_data_frames[i]
    # Derive bar positions from the frame itself. A date with fewer than 10
    # rows would otherwise give y, width and tick_label different lengths,
    # and barh raises on the shape mismatch.
    ax.barh(
        y=[n + 0.5 for n in range(len(frame_df))],
        width=frame_df['totalcountconfirmed'],
        tick_label=frame_df['county'],
        color=get_colors(frame_df['county'].values),
    )
    date = frame_df['date'].iloc[0]
    ax.set_title(f'COVID-19 Total Cases by County - {date}')
# The Figure class is instantiated directly (plt.Figure, not plt.figure()), so
# pyplot does not manage this figure -- presumably so the notebook shows only
# the animation below and not an extra static plot; confirm.
fig = plt.Figure(figsize=(12, 8))
ax = fig.add_subplot()

# One frame per entry of plot_data_frames, 100 ms apart, played once.
anim = FuncAnimation(
    fig,
    update,
    frames=len(plot_data_frames),
    init_func=init,
    interval=100,
    repeat=False,
)

# Embed the animation inline as an interactive JavaScript/HTML player.
# NOTE(review): to_jshtml embeds every frame; long date ranges may exceed
# matplotlib's rcParams['animation.embed_limit'] -- confirm output is complete.
display(HTML(anim.to_jshtml()))